package server

import (
	"fmt"
	"sync"

	"github.com/pingcap/tidb/mysql"
	"github.com/zeast/logs"
)

// _defaultBacklog is the default connection limit: 1<<20 (1,048,576,
// roughly one million). It is substituted whenever a non-positive limit is
// supplied.
var _defaultBacklog = int64(1 << 20)

// Counter tracks the number of live client connections against an upper
// limit.
type Counter interface {
	// SetMax sets the maximum count. Both implementations in this file
	// substitute _defaultBacklog for a non-positive value.
	SetMax(int64)
	// Max returns the maximum count.
	Max() int64
	// Size returns the current count.
	Size() int64
	// Incr tries to increment the count and reports whether it did; false
	// means the limit has been reached. The implementations in this file
	// return immediately rather than blocking.
	Incr() bool
	// Decr decrements the count, releasing a unit taken by Incr.
	Decr()
}

var _ Counter = (*ChanCount)(nil)

// ChanCount is a Counter backed by a buffered channel. Each successful Incr
// puts a token into the channel and each Decr takes one out, so the
// channel's capacity is the limit and its length is the current count.
//
// Incr and Decr never block. SetMax must be called before the counter is
// shared: it is not synchronized, and on the zero value (no channel yet)
// Incr always reports false.
type ChanCount struct {
	max int64
	ch  chan struct{}
}

// SetMax sets the limit; a non-positive m selects _defaultBacklog.
//
// The channel is allocated only on the first call. Later calls update the
// stored max but do not resize the channel, so Max keeps reporting the
// capacity chosen by the first call.
func (c *ChanCount) SetMax(m int64) {
	if m <= 0 {
		m = _defaultBacklog
	}
	c.max = m
	if c.ch == nil {
		c.ch = make(chan struct{}, c.max)
	}
}

// Incr takes a slot if one is free and reports whether it did. It returns
// false immediately when the channel is full, or when it was never
// allocated (a send on a nil channel never proceeds, so default is taken).
func (c *ChanCount) Incr() bool {
	select {
	case c.ch <- struct{}{}:
		return true
	default:
		return false
	}
}

// Decr releases a slot taken by Incr.
//
// An unbalanced call is ignored: there is nothing to release, or SetMax was
// never called so the channel is nil. A bare receive here would block the
// calling goroutine forever.
func (c *ChanCount) Decr() {
	select {
	case <-c.ch:
	default:
	}
}

// Size returns the number of slots currently taken.
func (c *ChanCount) Size() int64 {
	return int64(len(c.ch))
}

// Max returns the limit, which is the channel's capacity (0 before SetMax).
func (c *ChanCount) Max() int64 {
	return int64(cap(c.ch))
}

// IntCount is a Counter implemented as a mutex-guarded integer. Incr
// reports false instead of blocking once the limit has been reached.
//
// SetMax should be called before use: the zero value's limit is 0, so every
// Incr on it fails. The mutex is embedded, so Lock and Unlock are part of
// IntCount's method set; they guard max and cur.
type IntCount struct {
	sync.Mutex
	max int64
	cur int64
}

var _ Counter = (*IntCount)(nil)

// SetMax sets the limit; a non-positive m selects _defaultBacklog. The
// current count is left untouched, so lowering the limit below Size only
// rejects future Incr calls.
func (c *IntCount) SetMax(m int64) {
	if m <= 0 {
		m = _defaultBacklog
	}
	c.Lock()
	defer c.Unlock()
	c.max = m
}

// Incr increments the count and reports true while the count is below the
// limit. Otherwise it leaves the count unchanged and reports false.
//
// The comparison is >= so that at most max units are held at once. The
// former > check let the count reach max+1.
func (c *IntCount) Incr() bool {
	c.Lock()
	defer c.Unlock()
	if c.cur >= c.max {
		return false
	}
	c.cur++
	return true
}

// Decr releases one unit. An unbalanced call at zero is ignored, so the
// count never goes negative; a negative count would let later Incr calls
// exceed the limit.
func (c *IntCount) Decr() {
	c.Lock()
	defer c.Unlock()
	if c.cur > 0 {
		c.cur--
	}
}

// Size returns the current count.
func (c *IntCount) Size() int64 {
	c.Lock()
	defer c.Unlock()
	return c.cur
}

// Max returns the current limit.
func (c *IntCount) Max() int64 {
	c.Lock()
	defer c.Unlock()
	return c.max
}

// CountMgr is the package-level collection of per-user connection counters.
// It is nil until InitCountMgr assigns it.
var CountMgr *countMgr

// countMgr maps a user name to that user's Counter. The embedded RWMutex is
// taken around accesses to m.
type countMgr struct {
	sync.RWMutex
	m map[string]Counter // user name -> counter
}

//InitCountMgr init the global counter. username => maxConn
func InitCountMgr(cfg map[string]int, backlog int64) {
	cm := &countMgr{
		m: make(map[string]Counter),
	}
	if backlog <= 0 {
		backlog = _defaultBacklog
	}
	for k, v := range cfg {
		c := new(IntCount)
		c.SetMax(backlog + int64(v))
		cm.m[k] = c
		//log.Debug(k, c)
	}
	CountMgr = cm
}

// NotifyNew registers user with a fresh IntCount limited to _defaultBacklog,
// unless a counter for user already exists.
//
// The lookup is repeated under the write lock. With only the read-locked
// check, two concurrent calls for the same new user could both create a
// counter, and the second would overwrite the first, losing any count
// already recorded on it.
//
// NOTE(review): n is not used. Presumably it was meant to be the user's
// limit; confirm with callers before wiring it in.
func (cm *countMgr) NotifyNew(user string, n int64) {
	cm.RLock()
	_, ok := cm.m[user]
	cm.RUnlock()
	if ok {
		return
	}

	c := new(IntCount)
	c.SetMax(_defaultBacklog)

	cm.Lock()
	defer cm.Unlock()
	if _, exists := cm.m[user]; exists {
		return // another goroutine registered user in the meantime
	}
	cm.m[user] = c
}

// GetCounters returns a snapshot of the user name -> Counter map.
//
// The map is copied under the read lock. Returning the internal map let
// callers range over it while NotifyNew inserted entries, which is a data
// race on a Go map. The Counter values themselves are shared, not copied.
func (cm *countMgr) GetCounters() map[string]Counter {
	cm.RLock()
	defer cm.RUnlock()
	snapshot := make(map[string]Counter, len(cm.m))
	for user, c := range cm.m {
		snapshot[user] = c
	}
	return snapshot
}

// GetSize reports the current connection count for key. If no counter is
// registered under key, it logs an error and returns 0.
func (cm *countMgr) GetSize(key string) int64 {
	cm.RLock()
	counter := cm.m[key]
	cm.RUnlock()

	if counter != nil {
		return counter.Size()
	}
	logs.Errorf("GetSize not found counter: %s", key)
	return 0
}

// Incr takes one connection slot for user key.
//
// It returns a mysql ErrUsername error when no counter is registered for
// key, and a mysql ErrTooManyUserConnections error when the counter rejects
// the increment. Otherwise it returns nil.
func (cm *countMgr) Incr(key string) error {
	cm.RLock()
	c := cm.m[key]
	cm.RUnlock()
	if c == nil {
		return mysql.NewErrf(mysql.ErrUsername, `User %s does not exist`, key)
	}
	if !c.Incr() {
		return mysql.NewErr(mysql.ErrTooManyUserConnections, key)
	}
	return nil
}

// Decr releases one connection slot for user key. It returns an error when
// no counter is registered for key (a missing key yields a nil Counter).
func (cm *countMgr) Decr(key string) error {
	cm.RLock()
	c := cm.m[key]
	cm.RUnlock()
	if c == nil {
		return fmt.Errorf(`User %s does not exist`, key)
	}
	c.Decr()
	return nil
}
